package model

import (
	"database/sql"
	"fmt"
	"log"
)

// StockDataShowApi holds one row of the stock_data table as read and written
// by the functions in this file; the json tags define the field names used
// when the value is encoded as JSON.
type StockDataShowApi struct {
	ID         int     `json:"id"`          // auto-increment primary key uniquely identifying each record
	StockName  string  `json:"stock_name"`  // stock name
	TradeMoney float64 `json:"trade_money"` // trade amount (traded value)
	DiffMoney  float64 `json:"diff_money"`  // price change amount
	OpenPrice  float64 `json:"open_price"`  // opening price
	StockCode  string  `json:"stock_code"`  // stock code
	Date       string  `json:"date"`        // trading date
	MinPrice   float64 `json:"min_price"`   // lowest price
	Market     string  `json:"market"`      // market (e.g. `sh` for Shanghai, `sz` for Shenzhen)
	TradeNum   int     `json:"trade_num"`   // trade quantity
	ClosePrice float64 `json:"close_price"` // closing price
	MaxPrice   float64 `json:"max_price"`   // highest price
	Swing      float64 `json:"swing"`       // price amplitude
	DiffRate   float64 `json:"diff_rate"`   // rate of price change (rise/fall)
	Turnover   float64 `json:"turnover"`    // turnover rate
}

// SaveStockSApi inserts s into stock_data, or updates the existing row(s)
// that share its (stock_code, date) pair.
//
// It first counts rows matching s.StockCode and s.Date. When at least one
// exists, every other column of those rows is overwritten with the values in
// s; otherwise a new row is inserted. s.ID is neither read nor written.
//
// NOTE(review): the existence check and the write are separate statements
// outside any transaction, so two concurrent saves of the same key can both
// observe count == 0 and insert duplicate rows unless stock_data has a unique
// index on (stock_code, date) — confirm the schema.
//
// Failures are logged and returned wrapped with the stock code and date.
func (s *StockDataShowApi) SaveStockSApi() error {
	var count int
	err := db.QueryRow(
		"SELECT COUNT(*) FROM stock_data WHERE stock_code = ? AND date = ?",
		s.StockCode, s.Date,
	).Scan(&count)
	if err != nil {
		log.Printf("Failed to check if stock exists: %v", err)
		return fmt.Errorf("checking stock %q on %q: %w", s.StockCode, s.Date, err)
	}

	// Each branch keeps its SQL next to its argument list so the placeholder
	// order and the argument order cannot drift apart.
	var res sql.Result
	if count > 0 {
		res, err = db.Exec(`
			UPDATE stock_data
			SET stock_name = ?, trade_money = ?, diff_money = ?, open_price = ?, min_price = ?, market = ?,
				trade_num = ?, close_price = ?, max_price = ?, swing = ?, diff_rate = ?, turnover = ?
			WHERE stock_code = ? AND date = ?`,
			s.StockName, s.TradeMoney, s.DiffMoney, s.OpenPrice, s.MinPrice, s.Market,
			s.TradeNum, s.ClosePrice, s.MaxPrice, s.Swing, s.DiffRate, s.Turnover,
			s.StockCode, s.Date,
		)
	} else {
		res, err = db.Exec(`
			INSERT INTO stock_data (stock_name, trade_money, diff_money, open_price, stock_code, date,
				min_price, market, trade_num, close_price, max_price, swing, diff_rate, turnover)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
			s.StockName, s.TradeMoney, s.DiffMoney, s.OpenPrice, s.StockCode, s.Date,
			s.MinPrice, s.Market, s.TradeNum, s.ClosePrice, s.MaxPrice, s.Swing, s.DiffRate, s.Turnover,
		)
	}
	if err != nil {
		log.Printf("Error saving stock: %v", err)
		return fmt.Errorf("saving stock %q on %q: %w", s.StockCode, s.Date, err)
	}

	// Only the error matters here: the call is kept so that a driver unable
	// to report affected rows still surfaces that failure to the caller.
	if _, err := res.RowsAffected(); err != nil {
		log.Printf("Error getting rows affected: %v", err)
		return fmt.Errorf("reading rows affected for stock %q on %q: %w", s.StockCode, s.Date, err)
	}
	return nil
}

// BatchInsert adds every element of stocks to stock_data as a new row.
//
// All inserts share one transaction and one prepared statement. The first
// failing Exec stops the loop and returns its error, and the deferred
// Rollback then discards the rows already sent, so the batch is committed
// all-or-nothing. The ID field is not inserted. A nil or empty slice returns
// nil without touching the database.
func BatchInsert(stocks []StockDataShowApi) error {
	if len(stocks) == 0 {
		return nil
	}

	const insertStockSQL = `
		INSERT INTO stock_data (stock_name, trade_money, diff_money, open_price, stock_code, date, min_price, market, trade_num, close_price, max_price, swing, diff_rate, turnover)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit the deferred Rollback has nothing to undo;
	// its error is intentionally ignored.
	defer tx.Rollback()

	stmt, err := tx.Prepare(insertStockSQL)
	if err != nil {
		return err
	}
	defer stmt.Close()

	// Index into the slice rather than ranging by value so each record is
	// not copied per iteration.
	for i := range stocks {
		rec := &stocks[i]
		if _, execErr := stmt.Exec(
			rec.StockName, rec.TradeMoney, rec.DiffMoney, rec.OpenPrice, rec.StockCode,
			rec.Date, rec.MinPrice, rec.Market, rec.TradeNum,
			rec.ClosePrice, rec.MaxPrice, rec.Swing, rec.DiffRate, rec.Turnover,
		); execErr != nil {
			return execErr
		}
	}

	return tx.Commit()
}

// stockSortColumns is the set of stock_data columns that GetStockSAListPage
// accepts as its orderBy argument. Identifiers cannot be bound as SQL
// parameters, so the caller's value must match an entry here before it is
// spliced into the query text.
var stockSortColumns = map[string]struct{}{
	"id": {}, "stock_name": {}, "trade_money": {}, "diff_money": {}, "open_price": {},
	"stock_code": {}, "date": {}, "min_price": {}, "market": {}, "trade_num": {},
	"close_price": {}, "max_price": {}, "swing": {}, "diff_rate": {}, "turnover": {},
}

// stockSortDirections maps each accepted orderDirection spelling to the SQL
// keyword placed in the query; an empty direction means ascending.
var stockSortDirections = map[string]string{
	"":     "ASC",
	"asc":  "ASC",
	"ASC":  "ASC",
	"desc": "DESC",
	"DESC": "DESC",
}

// GetStockSAListPage retrieves one page of stock_data rows for stockCode.
//
// Rows are first sorted by orderBy/orderDirection to pick the page
// (LIMIT pageSize OFFSET (page-1)*pageSize). The selected page is then
// re-sorted by date ascending before it is returned.
//
// Arguments:
//   - page: 1-based page number; values below 1 are treated as 1.
//   - pageSize: must be at least 1.
//   - orderBy: must exactly match one of the selected stock_data column names.
//   - orderDirection: "asc", "ASC", "desc", "DESC", or "" (ascending).
//
// Invalid arguments return an error without querying the database. stockCode
// and the paging values are sent as bound parameters, so they cannot alter
// the SQL. The result is a nil slice (and nil error) when no rows match.
func GetStockSAListPage(page int, pageSize int, orderBy string, orderDirection string, stockCode string) ([]StockDataShowApi, error) {
	if _, ok := stockSortColumns[orderBy]; !ok {
		return nil, fmt.Errorf("invalid order column %q", orderBy)
	}
	direction, ok := stockSortDirections[orderDirection]
	if !ok {
		return nil, fmt.Errorf("invalid order direction %q", orderDirection)
	}
	if pageSize < 1 {
		return nil, fmt.Errorf("invalid page size %d", pageSize)
	}
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	// Only the allow-listed column and direction are formatted into the SQL;
	// every caller-controlled value travels as a bound parameter.
	query := fmt.Sprintf(`
		SELECT * FROM (
			SELECT id, stock_name, trade_money, diff_money, open_price, stock_code, date, min_price, market, trade_num, close_price, max_price, swing, diff_rate, turnover
			FROM stock_data
			WHERE stock_code = ?
			ORDER BY %s %s
			LIMIT ? OFFSET ?
		) e ORDER BY date ASC
	`, orderBy, direction)

	rows, err := db.Query(query, stockCode, pageSize, offset)
	if err != nil {
		return nil, fmt.Errorf("querying stock page for %q: %w", stockCode, err)
	}
	defer rows.Close()

	var stocks []StockDataShowApi
	for rows.Next() {
		var stock StockDataShowApi
		if err := rows.Scan(
			&stock.ID, &stock.StockName, &stock.TradeMoney, &stock.DiffMoney, &stock.OpenPrice, &stock.StockCode,
			&stock.Date, &stock.MinPrice, &stock.Market, &stock.TradeNum, &stock.ClosePrice, &stock.MaxPrice,
			&stock.Swing, &stock.DiffRate, &stock.Turnover,
		); err != nil {
			return nil, fmt.Errorf("scanning stock row for %q: %w", stockCode, err)
		}
		stocks = append(stocks, stock)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating stock rows for %q: %w", stockCode, err)
	}

	return stocks, nil
}
